# Emma training module
import torch
from nn.attention import Attention
from nn.encoder import Encoder
from nn.decoder import Decoder
from nn.seq2seq import Seq2Seq
import torch.nn as nn
import torch.optim as optim
import config
from lib.data_center import DataCenter
import lib.util as util
import numpy as np

# Select the GPU when CUDA is available, otherwise fall back to CPU.
# This module-level `device` is used below to place the model and loss.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("device = ",device)

def train(model, criterion, optimizer):
    """Placeholder for the training loop — not implemented yet.

    Args (matching the objects constructed in ``main``):
        model: the Seq2Seq model to train.
        criterion: loss function (``nn.CrossEntropyLoss`` in ``main``).
        optimizer: optimizer over ``model.parameters()`` (Adam in ``main``).

    NOTE(review): currently a stub that only prints ``1``; the actual
    epoch/batch loop, forward/backward pass, and optimizer step still
    need to be written.
    """
    print(1)


def main():
    """Load the dataset, build the seq2seq model, and run one forward pass.

    This is a smoke-test style entry point: it converts two sample
    sentences to index tensors, feeds them through the (untrained) model,
    and prints the raw output. No training loop runs here — see ``train``.
    """
    dataset = DataCenter()
    datas = dataset.get_data_set()
    dicts = dataset.load_dict()
    dataset.set_dict(dicts)

    # BUG FIX: the original code read `fk24[0]` / `fk24[1]` here, but `fk24`
    # is undefined anywhere in this module — it raised NameError at runtime.
    # The intended source is the `fanzxl` subset loaded just above.
    # (A no-op `for ...: continue` loop over `fanzxl` was also removed.)
    fanzxl = datas["fanzxl"]
    data1 = fanzxl[0]['sentence']
    data2 = fanzxl[1]['sentence']
    print(data1,data2)

    # Build the attention-based encoder/decoder from config hyperparameters.
    attn = Attention(config.ENC_HID_DIM, config.DEC_HID_DIM)
    enc = Encoder(config.INPUT_DIM, config.ENC_EMB_DIM, config.ENC_HID_DIM, config.DEC_HID_DIM, config.ENC_DROPOUT)
    dec = Decoder(config.OUTPUT_DIM, config.DEC_EMB_DIM, config.ENC_HID_DIM, config.DEC_HID_DIM, config.DEC_DROPOUT,
                  attn)
    model = Seq2Seq(enc, dec, device).to(device)
    model.train()

    # Optimizer and loss. NOTE(review): both are built but unused here;
    # presumably they are meant to be passed to `train` once it is written.
    criterion = nn.CrossEntropyLoss().to(device)
    optimizer = optim.Adam(model.parameters(), lr=1e-3)

    # Convert each sentence to a sequence of vocabulary indices.
    list_data1 = dataset.transfer_word_2_number(data1)
    list_data2 = dataset.transfer_word_2_number(data2)
    # NOTE(review): torch.tensor([seq, [1]]) only succeeds when len(seq) == 1
    # (ragged nested lists raise ValueError) — confirm the intended batch
    # layout against the Seq2Seq forward signature.
    src = torch.tensor([list_data1, [1]])
    trg = torch.tensor([list_data2, [1]])
    res = model(src, trg)
    print("res = ", res)


# Guard the entry point so importing this module for its definitions
# does not immediately kick off data loading and a model forward pass.
if __name__ == "__main__":
    main()
